/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "op_dtype_selection_strategy_allow_fp32_to_fp16.h"

#include <utility>

namespace fe {
/**
 * Build the AllowFp32ToFp16 dtype-selection strategy.
 *
 * @param format_dtype_querier_ptr     forwarded to OpDtypeSeletionStrategyBase; Run() queries the
 *                                     kernel's supported dtypes through format_dtype_querier_ptr_.
 * @param op_dtype_rise_matcher_ptr    matcher Run() tries first against the tensor's origin dtype.
 * @param op_dtype_reduce_matcher_ptr  fallback matcher Run() uses for DT_FLOAT tensors when the
 *                                     rise matcher finds no match.
 *
 * The matcher handles are sink parameters: they are taken by value and moved into the members,
 * so no second copy is made. NOTE(review): this assumes the *Ptr aliases are std::shared_ptr,
 * where a copy costs an atomic refcount update — confirm in the header.
 */
OpDtypeSelectionStrategyAllowFp32ToFp16::OpDtypeSelectionStrategyAllowFp32ToFp16(
    FormatDtypeQuerierPtr format_dtype_querier_ptr, OpDtypeRiseMatcherPtr op_dtype_rise_matcher_ptr,
    OpDtypeReduceMatcherPtr op_dtype_reduce_matcher_ptr)
    : OpDtypeSeletionStrategyBase(format_dtype_querier_ptr),
      op_dtype_rise_matcher_ptr_(std::move(op_dtype_rise_matcher_ptr)),
      op_dtype_reduce_matcher_ptr_(std::move(op_dtype_reduce_matcher_ptr)) {}

// Destructor: performs no work beyond the automatic destruction of members and base subobjects.
OpDtypeSelectionStrategyAllowFp32ToFp16::~OpDtypeSelectionStrategyAllowFp32ToFp16() = default;

/**
 * Match the current tensor's dtype against the dtypes supported by the op kernel. When the
 * tensor is fp32, the match may fall back to a lower-precision (fp16) kernel dtype.
 *
 * Steps:
 *   1. Query the kernel's supported dtypes for this tensor via format_dtype_querier_ptr_.
 *   2. Try op_dtype_rise_matcher_ptr_ against the origin dtype. Per the original design
 *      comment, this mode does not decrease precision.
 *   3. If that fails and the origin dtype is ge::DT_FLOAT, retry with
 *      op_dtype_reduce_matcher_ptr_, which allows fp32 -> fp16.
 * Matchers write their results into basic_info.matched_index_vec.
 *
 * @param basic_info  node, tensor desc and kernel info for the tensor being processed;
 *                    basic_info.matched_index_vec is passed to the matchers as the output.
 * @return SUCCESS once the supported dtypes were queried, whether or not a match was found.
 *         On no match, the log states matchedIndexVec remains the same. NOTE(review): this
 *         relies on the matchers not modifying it on failure — verify.
 *         Returns FAILED if the supported dtypes cannot be queried, or FE_CHECK_NOTNULL's
 *         error status if a required pointer is null.
 */
Status OpDtypeSelectionStrategyAllowFp32ToFp16::Run(SelectionBasicInfo &basic_info) {
  FE_CHECK_NOTNULL(basic_info.node);
  auto cur_op_desc_ptr = basic_info.node->GetOpDesc();
  FE_CHECK_NOTNULL(cur_op_desc_ptr);
  // These are dereferenced on every call below; reject nulls with an error status
  // instead of crashing.
  FE_CHECK_NOTNULL(basic_info.cur_tensor_desc);
  FE_CHECK_NOTNULL(format_dtype_querier_ptr_);
  FE_CHECK_NOTNULL(op_dtype_rise_matcher_ptr_);
  // Bind by const reference: avoids a copy if the getters return a reference, and
  // extends the temporary's lifetime if they return by value.
  const std::string &cur_op_desc_name = cur_op_desc_ptr->GetName();
  const std::string &cur_op_desc_type = cur_op_desc_ptr->GetType();
  FE_LOGD("Op[name=%s,type=%s]: match dtype for tensor %u in AllowFp32ToFp16.", cur_op_desc_name.c_str(),
          cur_op_desc_type.c_str(), basic_info.index);

  ge::DataType origin_dtype = basic_info.cur_tensor_desc->GetDataType();
  vector<ge::DataType> input_or_output_dtype_vec;
  if (format_dtype_querier_ptr_->GetSupportDataTypes(basic_info.op_kernel_info_ptr, basic_info.tensor_kernel_info_ptr,
                                                     *cur_op_desc_ptr, input_or_output_dtype_vec) != SUCCESS) {
    REPORT_FE_ERROR("[GraphOpt][DtypeJdg][AllowFp32ToFp16][Op %s type %s] Fail to get the support data_types.",
                    cur_op_desc_name.c_str(), cur_op_desc_type.c_str());
    return FAILED;
  }
  // Cast explicitly so the enum value matches the %u conversion specifier.
  FE_LOGD("Op[name=%s,type=%s]: match the origin dtype, the expected dtype is %u.", cur_op_desc_name.c_str(),
          cur_op_desc_type.c_str(), static_cast<uint32_t>(origin_dtype));
  // 1. Match the origin dtype in increasing mode first; in this mode the precision
  //    is not allowed to decrease.
  Status match_origin_dtype_res =
      op_dtype_rise_matcher_ptr_->Match(input_or_output_dtype_vec, origin_dtype, basic_info.matched_index_vec);
  if (match_origin_dtype_res != SUCCESS && origin_dtype == ge::DT_FLOAT) {
    // 2. Only for fp32 tensors: retry in reducing mode, which allows the precision
    //    to drop from fp32 to fp16. The reduce matcher is only needed here, so it is
    //    only checked here.
    FE_CHECK_NOTNULL(op_dtype_reduce_matcher_ptr_);
    FE_LOGD("Precision loss is allowed, try to match low precision dtype.");
    match_origin_dtype_res =
        op_dtype_reduce_matcher_ptr_->Match(input_or_output_dtype_vec, origin_dtype, basic_info.matched_index_vec);
  }

  if (match_origin_dtype_res == SUCCESS) {
    FE_LOGD("Op[name=%s,type=%s]: match the origin dtype success, some matched dtype in op kernel have been found.",
        cur_op_desc_name.c_str(), cur_op_desc_type.c_str());
    FE_LOGD("The size of input_or_output_dtype_vec is %zu, the size of matchedIndexVec is %zu.",
        input_or_output_dtype_vec.size(), basic_info.matched_index_vec.size());
  } else {
    FE_LOGD(
        "Op[name=%s,type=%s]: no matched the origin dtype, "
        "matchedIndexVec remain the same.",
        cur_op_desc_name.c_str(), cur_op_desc_type.c_str());
  }
  return SUCCESS;
}
}  // namespace fe
